import math
from math import sqrt
import argparse
from pathlib import Path
from unittest import TestCase

# torch

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR

# vision imports

from torchvision import transforms as T
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid, save_image

# dalle classes and utils

from dalle_pytorch import distributed_utils
# from dalle_pytorch import DiscreteVAE
from dalle_pytorch.dalle_pytorch_ori import DiscreteVAE
# from dalle_pytorch.dalle_pytorch_oriema import DiscreteVAE
# from dalle_pytorch.dalle_pytorch_ae import DiscreteVAE

# argument parsing

import sys
sys.path.insert(0, '/home/tiangel/DALLE_3D/Learning-to-Group')
from IPython import embed
import glob
# from pytorch3d.io import load_ply
from pytorch3d.io import load_ply, save_ply
from torch.utils.data import Dataset
import os
from partnet.utils.torch_pc import normalize_points as normalize_points_torch

from pytorch3d.io import IO
from pytorch3d.structures import Pointclouds
# from pytorch3d.vis.plotly_vis import AxisArgs, plot_batch_individually, plot_scene
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVOrthographicCameras, 
    PointsRasterizationSettings,
    PointsRenderer,
    PulsarPointsRenderer,
    PointsRasterizer,
    AlphaCompositor,
    NormWeightedCompositor
)
import matplotlib.pyplot as plt
import numpy as np
from geometry_utils import render_pts, rotate_pts, render_pts_with_label
import h5py

# Base command-line interface: where the data lives and the input size.
parser = argparse.ArgumentParser()

parser.add_argument(
    '--image_folder',
    type=str,
    required=True,
    help='path to your folder of images for learning the discrete VAE and its codebook',
)
parser.add_argument(
    '--image_size',
    type=int,
    default=128,
    help='image size',
)

# The parser is handed to distributed_utils.wrap_arg_parser and the returned
# parser is the one used from here on.
parser = distributed_utils.wrap_arg_parser(parser)


# Options grouped under "Training settings" in the --help output.
train_group = parser.add_argument_group('Training settings')

# Checkpoint file name; it is later joined onto ./outputs/vae_models.
train_group.add_argument(
    '--vae_path', type=str,
    help='path to your trained discrete VAE',
)

# Optimisation schedule.
train_group.add_argument('--epochs', type=int, default=20,
                         help='number of epochs')
train_group.add_argument('--batch_size', type=int, default=8,
                         help='batch size')
train_group.add_argument('--learning_rate', type=float, default=1e-3,
                         help='learning rate')
train_group.add_argument('--lr_decay_rate', type=float, default=0.98,
                         help='learning rate decay')

# Temperature annealing.
train_group.add_argument('--starting_temp', type=float, default=1.,
                         help='starting temperature')
train_group.add_argument('--temp_min', type=float, default=0.5,
                         help='minimum temperature to anneal to')
train_group.add_argument('--anneal_rate', type=float, default=1e-6,
                         help='temperature annealing rate')

# Logging.
train_group.add_argument('--num_images_save', type=int, default=2,
                         help='number of images to save')

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` is a trap with argparse: ``bool('False')`` is ``True``
    because any non-empty string is truthy.  This converter accepts the usual
    spellings instead and rejects anything it does not recognise.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, as it was with ``type=bool``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got %r' % (value,))


# Options grouped under "Model settings" in the --help output.
model_group = parser.add_argument_group('Model settings')

model_group.add_argument('--num_tokens', type = int, default = 8192, help = 'number of image tokens')

model_group.add_argument('--num_layers', type = int, default = 3, help = 'number of layers (should be 3 or above)')

model_group.add_argument('--num_resnet_blocks', type = int, default = 2, help = 'number of residual net blocks')

model_group.add_argument('--smooth_l1_loss', dest = 'smooth_l1_loss', action = 'store_true')

model_group.add_argument('--emb_dim', type = int, default = 512, help = 'embedding dimension')

model_group.add_argument('--hidden_dim', type = int, default = 256, help = 'hidden dimension')

model_group.add_argument('--dim1', type = int, default = 16, help = 'hidden dimension')

model_group.add_argument('--dim2', type = int, default = 32, help = 'hidden dimension')

model_group.add_argument('--final_points', type = int, default = 16, help = 'number of final points')

model_group.add_argument('--radius', type = float, default = 0.3, help = 'radius')

model_group.add_argument('--kl_loss_weight', type = float, default = 0., help = 'KL loss weight')

# Appended to 'test' to form the name of the output directory.
model_group.add_argument('--save_name', type = str, default = '1', help = 'suffix of the output directory name')

# Booleans go through _str2bool so that `--aug False` really disables it.
model_group.add_argument('--aug', type = _str2bool, default = True,
                         help = 'whether to apply random scale augmentation to the point clouds')

model_group.add_argument('--testae', type = _str2bool, default = False,
                         help = 'whether to run in autoencoder test mode')

args = parser.parse_args()

# constants

# data
IMAGE_SIZE, IMAGE_PATH = args.image_size, args.image_folder

# optimisation
EPOCHS, BATCH_SIZE = args.epochs, args.batch_size
LEARNING_RATE, LR_DECAY_RATE = args.learning_rate, args.lr_decay_rate

# model
NUM_TOKENS, NUM_LAYERS = args.num_tokens, args.num_layers
NUM_RESNET_BLOCKS = args.num_resnet_blocks
SMOOTH_L1_LOSS = args.smooth_l1_loss
EMB_DIM, HIDDEN_DIM = args.emb_dim, args.hidden_dim
KL_LOSS_WEIGHT = args.kl_loss_weight

# temperature annealing
STARTING_TEMP, TEMP_MIN, ANNEAL_RATE = (
    args.starting_temp,
    args.temp_min,
    args.anneal_rate,
)

# logging
NUM_IMAGES_SAVE = args.num_images_save

# initialize distributed backend

# Select the backend requested on the command line and bring it up before
# any model or data objects are created.
distr_backend = distributed_utils.set_backend_from_args(args)
distr_backend.initialize()

# Remember whether DeepSpeed is the active backend; later steps branch on it.
using_deepspeed = distributed_utils.using_backend(
    distributed_utils.DeepSpeedBackend
)

# data

# class PC_Dataset(Dataset):
#     def __init__(self, path):
#         self.data_dir = path
#         self.data_list = glob.glob(os.path.join(self.data_dir, '*.ply'))
#         self.len = len(self.data_list)
#         self.do_aug = args.aug

#     def __getitem__(self, index):
#         pc = load_ply(self.data_list[index])
#         points = normalize_points_torch(pc[0].unsqueeze(0)).squeeze()
#         if self.do_aug:
#             scale = points.new(1).uniform_(0.9, 1.05)
#             points[:, 0:3] *= scale
#         return (points, pc[1])

#     def __len__(self):
#         return self.len

# ds = PC_Dataset(IMAGE_PATH)

class PC_Dataset_h5(Dataset):
    """Point-cloud dataset read from the ``data`` array of an HDF5 file.

    The whole array is copied into memory at construction time and the file is
    closed again straight away, so no HDF5 handle is kept open (the original
    code never closed it).

    Args:
        path: file name of the HDF5 file, joined onto ``root``.  An absolute
            path overrides ``root`` (standard ``os.path.join`` behaviour).
        root: directory containing the dataset files.  Defaults to the
            location that used to be hard-coded.
        do_aug: whether to apply random scale augmentation.  ``None`` (the
            default) falls back to the ``--aug`` command-line flag.
    """

    def __init__(self, path, root='/home/tiangel/datasets/', do_aug=None):
        # np.array(...) materialises the data, so it stays valid after the
        # `with` block has closed the file.
        with h5py.File(os.path.join(root, path), 'r') as f:
            self.data = np.array(f['data'])
        self.len = self.data.shape[0]
        self.do_aug = args.aug if do_aug is None else do_aug

    def __getitem__(self, index):
        """Return sample ``index`` as a normalised (optionally scaled) tensor."""
        pc = torch.Tensor(self.data[index])
        # normalize_points_torch is called on a batch of one; the batch axis
        # is squeezed away again afterwards.
        points = normalize_points_torch(pc.unsqueeze(0)).squeeze()
        if self.do_aug:
            # One random scale factor in [0.9, 1.05) applied to the first
            # three columns (presumably xyz -- TODO confirm data layout).
            scale = points.new(1).uniform_(0.9, 1.05)
            points[:, 0:3] *= scale
        return points

    def __len__(self):
        """Return the number of samples in the dataset."""
        return self.len
# ds = PC_Dataset_h5('shapenet_plys_2048_knowninfo.h5')
ds = PC_Dataset_h5('shapenet_plys_2048_chair.h5')

# A distributed sampler is only built when Horovod is the active backend.
# NOTE(review): `data_sampler` is not passed to the DataLoader below, so it
# currently has no effect on iteration -- confirm whether that is intended.
data_sampler = None
if distributed_utils.using_backend(distributed_utils.HorovodBackend):
    data_sampler = torch.utils.data.distributed.DistributedSampler(
        ds,
        num_replicas=distr_backend.get_world_size(),
        rank=distr_backend.get_rank(),
    )

# Fixed order, and the last incomplete batch is dropped.
dl = DataLoader(ds, BATCH_SIZE, shuffle=False, drop_last=True)

# loaded_obj = torch.load(args.vae_path)
# --vae_path is not marked as required by the parser, so fail with a clear
# usage error here instead of a TypeError from os.path.join(…, None).
if args.vae_path is None:
    parser.error('--vae_path is required (file name inside ./outputs/vae_models)')

# Load onto the CPU first so the checkpoint does not depend on the GPU layout
# of the machine that saved it; the model is moved to the GPU further down.
# NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
loaded_obj = torch.load(os.path.join('./outputs/vae_models',args.vae_path),
                        map_location='cpu')
vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']

# The model is rebuilt from the hyper-parameters stored in the checkpoint,
# not from the command-line model settings.
vae = DiscreteVAE(
    **vae_params,
)

# keys = list(weights.keys())
# for k in keys:
#     weights['.'.join(k.split('.')[1:])] = weights[k]
#     weights.pop(k)


vae.load_state_dict(weights)
vae.eval().cuda()

# vae_params = dict(
#     image_size = IMAGE_SIZE,
#     num_layers = NUM_LAYERS,
#     num_tokens = NUM_TOKENS,
#     codebook_dim = EMB_DIM,
#     hidden_dim   = HIDDEN_DIM,
#     num_resnet_blocks = NUM_RESNET_BLOCKS,
#     dim1 = args.dim1,
#     dim2 = args.dim2,
#     radius = args.radius
# )

# vae = DiscreteVAE(
#     **vae_params,
#     smooth_l1_loss = SMOOTH_L1_LOSS,
#     kl_div_loss_weight = KL_LOSS_WEIGHT
# )
# if not using_deepspeed:
#     vae = vae.cuda()


# Sanity check on the dataset; only the root worker reports its size.
assert len(ds) > 0, 'folder does not contain any images'
if distr_backend.is_root_worker():
    print(f'{len(ds)} images found for training')

# optimizer

opt = Adam(vae.parameters(), lr=LEARNING_RATE)

# Cosine schedule spanning EPOCHS times the number of full batches.
sched = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer=opt,
    T_max=EPOCHS * int(len(ds) / BATCH_SIZE),
)
# sched = ExponentialLR(optimizer = opt, gamma = LR_DECAY_RATE)


if distr_backend.is_root_worker():
    # weights & biases experiment tracking (root worker only)

    import wandb

    # Settings recorded with the run; they come from the command line.
    model_config = {
        'num_tokens': NUM_TOKENS,
        'smooth_l1_loss': SMOOTH_L1_LOSS,
        'num_resnet_blocks': NUM_RESNET_BLOCKS,
        'kl_loss_weight': KL_LOSS_WEIGHT,
    }

    run = wandb.init(project='dalle_train_vae',
                     job_type='train_model',
                     config=model_config)

# distribute

distr_backend.check_batch_size(BATCH_SIZE)
deepspeed_config = dict(train_batch_size=BATCH_SIZE)

# Under DeepSpeed the raw dataset is handed over and no scheduler is passed;
# otherwise the ready-made DataLoader and scheduler are used.
distr_vae, distr_opt, distr_dl, distr_sched = distr_backend.distribute(
    args=args,
    model=vae,
    optimizer=opt,
    model_parameters=vae.parameters(),
    training_data=ds if using_deepspeed else dl,
    lr_scheduler=None if using_deepspeed else sched,
    config_params=deepspeed_config,
)

# Prefer scheduler in `deepspeed_config`.
# If the backend returned a scheduler while DeepSpeed is active, DeepSpeed
# handles the scheduling itself; this must be decided before the fallback
# below replaces a missing scheduler with the local one.
using_deepspeed_sched = bool(using_deepspeed) and distr_sched is not None
if distr_sched is None:
    distr_sched = sched

def save_model(path):
    """Write the VAE hyper-parameters and weights to ``path``.

    With DeepSpeed active, a DeepSpeed checkpoint is additionally written to
    a directory named after ``path`` without its extension plus ``-ds-cp``.
    The plain checkpoint at ``path`` is written by the root worker only.
    """
    checkpoint = {'hparams': vae_params}

    if using_deepspeed:
        target = Path(path)
        ds_checkpoint_dir = f'{target.parent / target.stem}-ds-cp'
        distr_vae.save_checkpoint(ds_checkpoint_dir, client_state=checkpoint)
        # No early exit here: a "normal" checkpoint is still written below.

    if distr_backend.is_root_worker():
        torch.save({**checkpoint, 'weights': vae.state_dict()}, path)

# rendering helper

def render_pytorch3d(renderer, pts, count, name):
    """Render a point cloud and save the image as a PNG inside ``save_dir``.

    Args:
        renderer: callable invoked as ``renderer(point_cloud, gamma=(1e-4,))``.
        pts: tensor of points; every point gets the same colour.
        count: integer used as the zero-padded four-digit file name prefix.
        name: string appended to the prefix; the file is
            ``<save_dir>/<count>_<name>.png``.
    """
    # Constant colour (0, 0.5, 0) per point.  The colours are allocated on the
    # same device as `pts` rather than forcing the default CUDA device, so
    # points and features can never end up on different devices.
    rgb = torch.zeros(pts.shape, device=pts.device)
    rgb[:, 1] = 0.5
    point_cloud = Pointclouds(points=[pts], features=[rgb])

    rendered_img = renderer(point_cloud, gamma=(1e-4,))
    # Exact zeros are replaced by ones (presumably to turn the empty
    # background white -- TODO confirm).
    rendered_img[rendered_img == 0] = 1

    fig = plt.figure(figsize=(10, 10))
    try:
        plt.imshow(rendered_img[0, ..., :3].detach().cpu().numpy())
        plt.axis("off")
        plt.savefig(os.path.join(save_dir, '%04d'%count+'_'+name+'.png'),dpi=300)
    finally:
        # Always release the figure, even when saving fails.
        plt.close(fig)

# Run state.
global_step = 0
temp = STARTING_TEMP

# Output directory for this run: ./outputs/vae_outputs/test<save_name>.
# makedirs(exist_ok=True) also creates missing parent directories and cannot
# fail when another worker creates the directory at the same moment, unlike
# the former exists()/mkdir() pair.
save_dir = os.path.join('./outputs/vae_outputs','test'+args.save_name)
os.makedirs(save_dir, exist_ok=True)

# Index of the current batch, used as the file name prefix.
count = 0

# Metric accumulators.
cd_loss_list = []
emd_loss_list = []
vq_loss_list = []
perplexity_list = []

# Camera pose; the positional arguments follow PyTorch3D's
# look_at_view_transform signature (dist, elev, azim).
# NOTE: `T` rebinds the module-level name that the imports section used for
# torchvision.transforms.
R, T = look_at_view_transform(50, 10, 45)
cameras = FoVOrthographicCameras(R=R, T=T, znear=0.01).cuda()

# 256x256 output, one point per pixel.
raster_settings = PointsRasterizationSettings(
    image_size=(256, 256),
    radius=0.005,
    points_per_pixel=1,
)

rasterizer = PointsRasterizer(
    cameras=cameras,
    raster_settings=raster_settings,
)
renderer = PulsarPointsRenderer(rasterizer=rasterizer).cuda()
#for i, (images, _) in enumerate(distr_dl):
# for i, (images, _) in enumerate(dl):
def _build_code_variants(codes_a, codes_b):
    """Stack two code sequences together with seven mixtures of them.

    Rows of the result, in order:
      0, 1  the unmodified inputs
      2     first half of ``codes_a`` followed by the leading part of ``codes_b``
      3     first quarter of ``codes_a`` followed by the leading part of ``codes_b``
      4     first three quarters of ``codes_a`` followed by the first quarter of ``codes_b``
      5, 6  a random split of the concatenated codes into two equal parts
      7, 8  a second random split (the index array is shuffled again)

    The sequence length is read from ``codes_a`` instead of being fixed at
    128; for length 128 the slices are identical to the former hard-coded
    ones (64/64, 32/96, 96/32).
    """
    n = codes_a.shape[0]
    half, quarter = n // 2, n // 4

    splice_half = torch.cat((codes_a[:half], codes_b[:n - half]))
    splice_quarter = torch.cat((codes_a[:quarter], codes_b[:n - quarter]))
    splice_three_quarters = torch.cat((codes_a[:n - quarter], codes_b[:quarter]))

    pooled = torch.cat((codes_a, codes_b))
    idx_arr = np.arange(pooled.shape[0])
    np.random.shuffle(idx_arr)
    split_a, split_b = pooled[idx_arr[:n]], pooled[idx_arr[n:]]
    np.random.shuffle(idx_arr)
    split_c, split_d = pooled[idx_arr[:n]], pooled[idx_arr[n:]]

    return torch.stack((
        codes_a, codes_b,
        splice_half, splice_quarter, splice_three_quarters,
        split_a, split_b, split_c, split_d,
    ), dim=0)


# For every batch: encode the samples to codebook indices, mix the codes of
# the first two samples in several ways, decode all variants and save them
# (together with the two originals) as .ply files.
for i, images in enumerate(dl):
    images = images.cuda()
    if len(images.shape) == 2:
        images = images.unsqueeze(0)

    # The mixing needs two samples; fail with a clear message instead of an
    # IndexError further down.
    if images.shape[0] < 2:
        raise ValueError(
            'need at least two point clouds per batch to mix their codes; '
            'use --batch_size >= 2')

    # Pure inference: nothing below calls backward(), so skip autograd
    # bookkeeping.
    with torch.no_grad():
        codebook_indices = distr_vae.get_codebook_indices(images)
        codes = _build_code_variants(codebook_indices[0], codebook_indices[1])
        pcs = distr_vae.decode(codes)
    # for analysis the learned embedding
    #for k in range(31): 
    #    codes = (torch.ones([16,128]).cuda() * torch.arange(16*k, 16*(k+1)).unsqueeze(1).cuda()).long()
    #    pcs = distr_vae.decode(codes)
    #    for j in range(pcs.shape[0]):
    #        save_ply(os.path.join(save_dir,'%03d'%count+'internal_%d.ply'%(j+k*16)), pcs[j])

    #codes = torch.floor(torch.rand([16, 128]) * 128).long().cuda()
    #pcs = distr_vae.decode(codes)
    #for j in range(pcs.shape[0]):
    #    save_ply(os.path.join(save_dir,'%03d'%count+'random_%d.ply'%(j)), pcs[j])

    # The two source samples, followed by every decoded variant.
    save_ply(os.path.join(save_dir,'%04d'%count+'ori_0.ply'), images[0])
    save_ply(os.path.join(save_dir,'%04d'%count+'ori_1.ply'), images[1])
    #render_pts_with_label(os.path.join(save_dir, '%04d'%count+'0_ori.png'), rotate_pts(images[0].cpu().numpy(), 100, 15),  torch.randint(3,[images[0].shape[0],1]).numpy())
    #render_pts_with_label(os.path.join(save_dir, '%04d'%count+'1_ori.png'), rotate_pts(images[1].cpu().numpy(), 100, 15),  torch.randint(3,[images[1].shape[0],1]).numpy())
    for j in range(pcs.shape[0]):
        save_ply(os.path.join(save_dir,'%04d'%count+'inter_%d.ply'%j), pcs[j])
        #render_pts_with_label(os.path.join(save_dir, '%04d'%count+'_inter_%d.png'%j), rotate_pts(pcs[j].detach().cpu().numpy(), 100, 12),  torch.randint(3,[pcs[j].shape[0],1]).numpy(), point_size=16)
    count += 1
    # Debugging aid: drop into an interactive IPython shell after the sixth
    # batch (i == 5).
    if i == 5:
        embed()

    #from sklearn.manifold import TSNE
    #from matplotlib import cm

    #tsne = TSNE(2, verbose=1)
    #tsne_proj = tsne.fit_transform(list(distr_vae.quantize_layer._embedding.parameters())[0].detach().cpu().numpy())
    ## Plot those points as a scatter plot and label them based on the pred labels
    #cmap = cm.get_cmap('tab20')
    #fig, ax = plt.subplots(figsize=(16,16))
    #num_categories = 512
    #for lab in range(num_categories):
    #    indices = test_predictions==lab
    #    ax.scatter(tsne_proj[indices,0],tsne_proj[indices,1], c=np.array(cmap(lab)).reshape(1,4), label = lab ,alpha=0.5)
    #ax.legend(fontsize='large', markerscale=2)
    #plt.show()
    
    # if using_deepspeed:
    #     # Gradients are automatically zeroed after the step
    #     distr_vae.backward(loss)
    #     distr_vae.step()
    # else:
    #     distr_opt.zero_grad()
    #     loss.backward()
    #     distr_opt.step()
    # if not using_deepspeed_sched:
    #     distr_sched.step()

    # logs = {}
    # if args.testae:
    #     for j in range(images.shape[0]):
    #         save_ply(os.path.join(save_dir,'%04d'%count+'_ori.ply'), images[j])
    #         save_ply(os.path.join(save_dir,'%04d'%count+'_recons.ply'), recons[j].reshape(-1,3))
    #     count+=1

    # # if i % 100 == 0:
    # else:
    #    if distr_backend.is_root_worker():
    #        k = NUM_IMAGES_SAVE

    #        with torch.no_grad():
    #            codes = vae.get_codebook_indices(images[:k])
    #            hard_recons = vae.decode(codes)

    #        images, recons = map(lambda t: t[:k], (images, recons))
    #        images, recons, hard_recons, codes = map(lambda t: t.detach().cpu(), (images, recons, hard_recons, codes))
    #        # images, recons, hard_recons = map(lambda t: make_grid(t.float(), nrow = int(sqrt(k)), normalize = True, range = (-1, 1)), (images, recons, hard_recons))
    #        for j in range(images.shape[0]):
    #            save_ply(os.path.join(save_dir,'%04d'%count+'_ori.ply'), images[j])
    #            save_ply(os.path.join(save_dir,'%04d'%count+'_recons.ply'), recons[j].reshape(-1,3))
    #            save_ply(os.path.join(save_dir,'%04d'%count+'_hardrecons.ply'), hard_recons[j].reshape(-1,3))
    #            # save_ply(os.path.join('./vae_outputs','test'+args.save_name,'%04d'%count+'_our.ply'), recons[j].reshape(10000,3))
    #            count+=1